# -*- coding: utf-8 -*-
"""
@date: 2020/12/21 20:13
@file: input_data.py
@author: lilong
@desc: Shows the rawest form of the data as it is fed into BERT
"""

from torch.utils.data import DataLoader
from bert_torch.dataset.vocab import TorchVocab
from bert_torch.dataset import BERTDataset, WordVocab


def main():
    """Build a vocab, a dataset and a DataLoader, then print the first batch.

    The script reads a vocabulary file and a text corpus from hard-coded
    relative paths, wraps the corpus in a ``BERTDataset``, and prints the
    first batch produced by a ``DataLoader`` so the raw model input can be
    inspected. Nothing is returned.
    """
    # --- Vocabulary -------------------------------------------------------
    vocab_file = "../../pretrainedModel/chinese_L-12_H-768_A-12/tt_vocab.txt"
    print("Loading Vocab:", vocab_file)
    loaded_vocab = WordVocab.load_vocab(vocab_file)
    print("Vocab Size: ", len(loaded_vocab))
    print("vocab:", loaded_vocab)

    # The loaded result is passed through the WordVocab constructor to get
    # the vocabulary object handed to the dataset.
    # NOTE(review): the return type of load_vocab is not visible from this
    # file -- confirm that this second wrap is intended.
    word_vocab = WordVocab(loaded_vocab)

    # --- Training dataset -------------------------------------------------
    corpus_path = "../../data/score_test.txt"
    print("Loading Train Dataset:", corpus_path)
    dataset = BERTDataset(
        corpus_path,
        word_vocab,
        seq_len=50,
        corpus_lines=100,
        on_memory=True,
    )

    # --- DataLoader -------------------------------------------------------
    print("Creating Dataloader")
    loader = DataLoader(dataset, batch_size=64, num_workers=2)

    # --- Inspect the rawest form of the model input -----------------------
    for batch in loader:
        # Shallow-copy the batch mapping before printing. No device transfer
        # (GPU or CPU) is performed here; the values are printed exactly as
        # the loader produced them.
        batch_copy = dict(batch.items())
        print(batch_copy)
        # Only the first batch is of interest.
        break


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
